{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B30hk2csRZ5F"
   },
   "source": [
    "## Using 🤗 Hugging Face Models with Tensorflow + TPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dvFEN4ogSWKS"
   },
   "source": [
    "Most of this notebook is designed to be run on a Colab TPU. To access TPU on Colab, go to `Runtime -> Change runtime type` and choose `TPU`. Some parts of the code may need to be changed when running on a Google Cloud TPU VM or TPU Node. We have indicated in the code where these changes may be necessary.\n",
    "\n",
    "At busy times, you may find that there's a lot of competition for TPUs and it can be hard to get access to a free one on Colab. Keep trying!\n",
    "\n",
    "This notebook is focused on usable code, but if you'd like a more high-level explanation of how to work with TPUs, please check out our [associated TPU tutorial](https://huggingface.co/docs/transformers/main/en/perf_train_tpu_tf)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1dCyzdc0TAEF"
   },
   "source": [
    "First, install up-to-date versions of `transformers` and `datasets` if you don't have them already."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jeF7wNTst-DP"
   },
   "outputs": [],
   "source": [
    "!pip install --upgrade transformers datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rZSNCjbiTJ0y"
   },
   "source": [
    "### Initialize your TPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L5XoRQLqTSA4"
   },
   "source": [
    "This next block will need to be modified depending on how you're accessing the TPU. For Colab, this code should work fine. When running on a TPU VM, pass the argument `tpu=\"local\"` to the `TPUClusterResolver`. When running on a non-Colab TPU Node, you'll need to pass the address of the TPU resource. When debugging on CPU/GPU, skip this block."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mfP7ni6twlhu"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "resolver = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
    "# On TPU VMs use this line instead:\n",
    "# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=\"local\")\n",
    "tf.config.experimental_connect_to_cluster(resolver)\n",
    "tf.tpu.experimental.initialize_tpu_system(resolver)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jVcUx-zUUKUc"
   },
   "source": [
    "### Prepare your `Strategy`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i6rjYRjBUM1Z"
   },
   "source": [
    "In TensorFlow, a `Strategy` object determines how models and data should be distributed across workers. There is a `TPUStrategy` specifically for TPU. However, when debugging, we recommend starting with the simplest `OneDeviceStrategy` to make sure your code works on CPU, and then swapping it for the `TPUStrategy` once you're sure it's bug-free."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0gI46C0lUJos"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "strategy = tf.distribute.TPUStrategy(resolver)\n",
    "# For testing without a TPU use this line instead:\n",
    "# strategy = tf.distribute.OneDeviceStrategy(\"/cpu:0\")"
   ]
  },
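  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you'd like the same notebook to run with or without a TPU, the two options above can be combined. The cell below is a minimal sketch rather than part of the original workflow. It assumes that `TPUClusterResolver()` raises a `ValueError` when no TPU can be found, which may not hold in every environment, so adjust the exception type if you see a different error."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "try:\n",
    "    # Assumption: this raises ValueError when no TPU is available\n",
    "    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
    "    tf.config.experimental_connect_to_cluster(resolver)\n",
    "    tf.tpu.experimental.initialize_tpu_system(resolver)\n",
    "    strategy = tf.distribute.TPUStrategy(resolver)\n",
    "except ValueError:\n",
    "    # Fall back to a single CPU device for debugging\n",
    "    strategy = tf.distribute.OneDeviceStrategy(\"/cpu:0\")\n",
    "\n",
    "print(\"Replicas in sync:\", strategy.num_replicas_in_sync)"
   ]
  },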
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "iW-kSIUxXpRe"
   },
   "source": [
    "### Load and preprocess training data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bpMV3FjYX-Xs"
   },
   "source": [
    "In order for TPU training to work, you must create the model used for training inside the `Strategy.scope()`. However, other things like Hugging Face `tokenizers` and the `Dataset` do not need to be created in this scope.\n",
    "\n",
    "For this example we will use CoLA, which is a small and simple binary text classification dataset from the GLUE benchmark.\n",
    "\n",
    "We also pad all samples to the maximum length, firstly to make it easier to load them as an array, but secondly because this avoids issues with XLA later. For more information on XLA compilation and TPUs, see the [associated TPU tutorial](https://huggingface.co/docs/transformers/main/en/perf_train_tpu_tf)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HZL1IteEXwY0"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from datasets import load_dataset\n",
    "import numpy as np\n",
    "\n",
    "model_checkpoint = \"distilbert-base-cased\"\n",
    "dataset = load_dataset(\"glue\", \"cola\")[\"train\"]\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "\n",
    "# For simplicity, let's just tokenize our dataset as NumPy arrays\n",
    "# padded to the maximum length. We discuss other options below!\n",
    "train_data = tokenizer(\n",
    "    dataset[\"sentence\"],\n",
    "    padding=\"max_length\",\n",
    "    truncation=True,\n",
    "    max_length=128,\n",
    "    return_tensors=\"np\",\n",
    ")\n",
    "train_data = dict(train_data)  # Because the tokenizer returns a dict subclass\n",
    "train_labels = np.array(dataset[\"label\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OWZoq4MRqcUn"
   },
   "source": [
    "### Create your Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ztn_wVn1rO56"
   },
   "source": [
    "While preprocessing data you can operate outside of the `strategy.scope()`, but model creation **must** take place inside it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UazsqDKKqhRw"
   },
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForSequenceClassification\n",
    "\n",
    "with strategy.scope():\n",
    "    model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n",
    "    # You can compile with jit_compile=True when debugging on CPU or GPU to check\n",
    "    # that XLA compilation works. Remember to take it out when actually running\n",
    "    # on TPU, though - XLA compilation will be handled for you when running with a\n",
    "    # TPUStrategy!\n",
    "    model.compile(optimizer=\"adam\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YfklfbsjY0hK"
   },
   "source": [
    "### Create the `tf.data.Dataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZzPCtUEydaoj"
   },
   "source": [
    "Keras methods like `fit()` can usually accept a broad range of inputs - a `list`/`tuple`/`dict` of `np.ndarray` or `tf.Tensor`, Python generators, `tf.data.Dataset`, and so on. **This is not the case on TPU.**\n",
    "\n",
    "On TPU, your input must always be a `tf.data.Dataset`. If you pass anything else\n",
    "to `model.fit()` when using a `TPUStrategy`, it will try to coerce it into a `tf.data.Dataset`. This sometimes works, but will create a lot of console spam and warnings even when it does. As a result, we recommend explicitly creating a `tf.data.Dataset` in all cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Xp7tmvy8uJ1f"
   },
   "outputs": [],
   "source": [
    "# The batch size will be split among TPU workers\n",
    "# so we scale it up based on how many of them there are\n",
    "BATCH_SIZE = 8 * strategy.num_replicas_in_sync\n",
    "\n",
    "tf_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels))\n",
    "tf_dataset = tf_dataset.shuffle(len(tf_dataset))\n",
    "# You should use drop_remainder on TPU where possible, because a change in the\n",
    "# batch size will require a new XLA compilation\n",
    "tf_dataset = tf_dataset.batch(BATCH_SIZE, drop_remainder=True)\n",
    "\n",
    "tf_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S-Km-TQjvy0v"
   },
   "source": [
    "### Train your model!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XuE4pwstxKwu"
   },
   "source": [
    "If you made it this far, then this next line should feel very familiar. Note that `fit()` doesn't actually need to be in the `scope()`, as long as the model and dataset were created there!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ONdYZNYzuTrH"
   },
   "outputs": [],
   "source": [
    "model.fit(tf_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t0iwTQtXluR7"
   },
   "source": [
    "And that's it! You just trained a Hugging Face model on TPU."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Om6K_JWP5wFO"
   },
   "source": [
    "## Advanced dataset creation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Q1AmBaLW5imV"
   },
   "source": [
    "Although the code above is perfectly usable, the dataset creation has been very simplified. We padded every sample to the maximum length in the whole dataset, and we also loaded the whole dataset into memory. When your data is too big for this to work, you will need to use a different approach instead.\n",
    "\n",
    "Below, we're going to list a few possible approaches to try. Note that some of these approaches may not work on Colab or TPU Node, so don't panic if you get errors! We'll try to indicate which code will work where, and what the advantages and disadvantages of each method are. When adapting this code for your own projects, we recommend choosing only one of these approaches, don't try to do them all at once!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xIhnkALTvDML"
   },
   "source": [
    "### Convert your data to `TFRecord`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7m-pR5QjvJ4y"
   },
   "source": [
    "`TFRecord` is the standard `tf.data` format for storing training data. For very large training jobs, it's often worth preprocessing your data and storing it all as TFRecord, then building your own `tf.data` pipeline on top of it. This is more work, and often requires you to pay for cloud storage, but it works for training on a wide range of devices (including TPU VM, TPU Node and Colab), and allows for truly massive data pipeline throughput.\n",
    "\n",
    "When converting to TFRecord, it's a good idea to do your preprocessing and tokenization before writing the TFRecord, so you don't have to do it every time the data is loaded. However, if you intend to use **train-time augmentations** you should be careful **not** to apply those before writing the `TFRecord`, or else you'll get exactly the same augmentation each epoch, which defeats the purpose of augmenting your data in the first place! Instead, you should apply augmentations in the `tf.data` pipeline that loads your data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wL4Vs_NdJm2n"
   },
   "source": [
    "First, we initialize our TPU. Skip this block if you're running on CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_46u-c2VJm2p"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "resolver = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
    "# On TPU VMs use this line instead:\n",
    "# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=\"local\")\n",
    "tf.config.experimental_connect_to_cluster(resolver)\n",
    "tf.tpu.experimental.initialize_tpu_system(resolver)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ieJGCHYWJm2r"
   },
   "source": [
    "Next, we load our strategy, dataset, tokenizer and model just like we did in the first example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fprh0rJiJm2s"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from transformers import AutoTokenizer, TFAutoModelForSequenceClassification\n",
    "from datasets import load_dataset\n",
    "\n",
    "strategy = tf.distribute.TPUStrategy(resolver)\n",
    "# For testing without a TPU use this line instead:\n",
    "# strategy = tf.distribute.OneDeviceStrategy(\"/cpu:0\")\n",
    "\n",
    "BATCH_SIZE = 8 * strategy.num_replicas_in_sync\n",
    "\n",
    "dataset = load_dataset(\"glue\", \"cola\", split=\"train\")\n",
    "\n",
    "model_checkpoint = \"distilbert-base-cased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "\n",
    "with strategy.scope():\n",
    "    model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n",
    "    model.compile(optimizer=\"adam\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PpwyCzJkJtF8"
   },
   "source": [
    "Now, let's tokenize our Hugging Face dataset.\n",
    "\n",
    "**Tip:** When using this method in practice, you probably won't be able to load your entire dataset in memory - instead, load a chunk of the dataset at a time and convert that to a TFRecord file, then repeat until you've covered the entire dataset, then use the list of all the files to create the TFRecordDataset later. In this example, we'll just create a single file for simplicity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "igz_ErzlWCz0"
   },
   "outputs": [],
   "source": [
    "tokenized_data = tokenizer(\n",
    "    dataset[\"sentence\"],\n",
    "    padding=\"max_length\",\n",
    "    truncation=True,\n",
    "    max_length=128,\n",
    "    return_tensors=\"np\",\n",
    ")\n",
    "labels = dataset[\"label\"]\n",
    "\n",
    "with tf.io.TFRecordWriter(\"dataset.tfrecords\") as file_writer:\n",
    "    for i in range(len(labels)):\n",
    "        features = {\n",
    "            \"input_ids\": tf.train.Feature(\n",
    "                int64_list=tf.train.Int64List(value=tokenized_data[\"input_ids\"][i])\n",
    "            ),\n",
    "            \"attention_mask\": tf.train.Feature(\n",
    "                int64_list=tf.train.Int64List(value=tokenized_data[\"attention_mask\"][i])\n",
    "            ),\n",
    "            \"labels\": tf.train.Feature(\n",
    "                int64_list=tf.train.Int64List(value=[labels[i]])\n",
    "            ),\n",
    "        }\n",
    "        features = tf.train.Features(feature=features)\n",
    "        example = tf.train.Example(features=features)\n",
    "        record_bytes = example.SerializeToString()\n",
    "        file_writer.write(record_bytes)"
   ]
  },
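  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For larger datasets, the chunked approach from the tip above could look something like the cell below. This is only a sketch: `SHARD_SIZE`, `write_shard` and the `dataset-00000.tfrecords` filename pattern are hypothetical names chosen for illustration, and it reuses the `dataset` and `tokenizer` objects loaded earlier. Each chunk is tokenized and written to its own file, so only one chunk of tokenized data is in memory at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical shard size and filename pattern, chosen only for illustration\n",
    "SHARD_SIZE = 2048\n",
    "\n",
    "\n",
    "def write_shard(path, input_ids, attention_mask, shard_labels):\n",
    "    # Write one chunk of tokenized samples to a single TFRecord file\n",
    "    with tf.io.TFRecordWriter(path) as file_writer:\n",
    "        for ids, mask, label in zip(input_ids, attention_mask, shard_labels):\n",
    "            features = {\n",
    "                \"input_ids\": tf.train.Feature(int64_list=tf.train.Int64List(value=ids)),\n",
    "                \"attention_mask\": tf.train.Feature(int64_list=tf.train.Int64List(value=mask)),\n",
    "                \"labels\": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),\n",
    "            }\n",
    "            example = tf.train.Example(features=tf.train.Features(feature=features))\n",
    "            file_writer.write(example.SerializeToString())\n",
    "\n",
    "\n",
    "shard_paths = []\n",
    "for shard_index, start in enumerate(range(0, len(dataset), SHARD_SIZE)):\n",
    "    # Slicing a Hugging Face Dataset returns a dict of lists for just this chunk\n",
    "    chunk = dataset[start : start + SHARD_SIZE]\n",
    "    tokenized = tokenizer(\n",
    "        chunk[\"sentence\"],\n",
    "        padding=\"max_length\",\n",
    "        truncation=True,\n",
    "        max_length=128,\n",
    "        return_tensors=\"np\",\n",
    "    )\n",
    "    path = f\"dataset-{shard_index:05d}.tfrecords\"\n",
    "    write_shard(path, tokenized[\"input_ids\"], tokenized[\"attention_mask\"], chunk[\"label\"])\n",
    "    shard_paths.append(path)\n",
    "\n",
    "# The list of shard files can be passed straight to TFRecordDataset\n",
    "sharded_dataset = tf.data.TFRecordDataset(shard_paths)"
   ]
  },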
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0AcBCIuGgYk1"
   },
   "source": [
    "Now, to load the dataset we build a `TFRecordDataset` using the filenames of the file(s) we saved. Ordinarily, you would need to create your own bucket in Google Cloud Storage, upload files to there, and handle authenticating your Python code so it can access it. However, for the sake of this example, we have uploaded the example file to a public bucket for you, so you can get started quickly!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HzxZt2T1lp9B"
   },
   "outputs": [],
   "source": [
    "def decode_fn(sample):\n",
    "    features = {\n",
    "        \"input_ids\": tf.io.FixedLenFeature((128,), dtype=tf.int64),\n",
    "        \"attention_mask\": tf.io.FixedLenFeature((128,), dtype=tf.int64),\n",
    "        \"labels\": tf.io.FixedLenFeature((1,), dtype=tf.int64),\n",
    "    }\n",
    "    return tf.io.parse_example(sample, features)\n",
    "\n",
    "# TFRecordDataset can handle gs:// paths!\n",
    "tf_dataset = tf.data.TFRecordDataset([\"gs://matt-tf-tpu-tutorial-datasets/cola/dataset.tfrecords\"])\n",
    "tf_dataset = tf_dataset.map(decode_fn)\n",
    "tf_dataset = tf_dataset.shuffle(len(dataset)).batch(BATCH_SIZE, drop_remainder=True)\n",
    "tf_dataset = tf_dataset.apply(\n",
    "    tf.data.experimental.assert_cardinality(len(labels) // BATCH_SIZE)\n",
    ")"
   ]
  },
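  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `assert_cardinality` call above is there because a `TFRecordDataset` cannot know how many records its files contain without reading them, so Keras has no way of knowing in advance how many batches make up an epoch. As a quick check, the sketch below compares the reported cardinality before and after the assertion. It assumes the `dataset.tfrecords` file written earlier is still in the working directory, and it reuses `decode_fn`, `labels` and `BATCH_SIZE` from the cells above; `local_dataset` is a name introduced only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumption: dataset.tfrecords was written locally by the earlier cell\n",
    "local_dataset = tf.data.TFRecordDataset([\"dataset.tfrecords\"])\n",
    "local_dataset = local_dataset.map(decode_fn).batch(BATCH_SIZE, drop_remainder=True)\n",
    "\n",
    "# A TFRecord-backed dataset reports an unknown size\n",
    "print(local_dataset.cardinality() == tf.data.UNKNOWN_CARDINALITY)\n",
    "\n",
    "local_dataset = local_dataset.apply(\n",
    "    tf.data.experimental.assert_cardinality(len(labels) // BATCH_SIZE)\n",
    ")\n",
    "# After the assertion, the number of batches is known\n",
    "print(local_dataset.cardinality().numpy())"
   ]
  },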
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "df1XRvk_myil"
   },
   "source": [
    "And now we can simply fit our dataset as before."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uAgMdM1rYuYy"
   },
   "outputs": [],
   "source": [
    "model.fit(tf_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WZIxiqSjq_0C"
   },
   "source": [
    "In summary:\n",
    "\n",
    "**TFRecord advantages:**\n",
    "- Works on all TPU instances (if the TFRecords are [stored in Google Cloud](https://www.tensorflow.org/api_docs/python/tf/io/gfile/GFile))\n",
    "- Can support huge datasets and massive throughput\n",
    "- Suitable for training on even entire TPU pods (!)\n",
    "- Preprocessing is done in advance, maximizing training speed\n",
    "\n",
    "**TFRecord disadvantages:**\n",
    "- Cloud storage isn't free!\n",
    "- Some datatypes (e.g. images) can take up a lot of space in this format"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "J7tx2EsTtGZr"
   },
   "source": [
    "### Stream from raw data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pz7Dku5BtL7O"
   },
   "source": [
    "In all of the examples above, we preprocessed data with a `tokenizer`, and then loaded the preprocessed data to fit our model. However, there is an alternate approach: The data can be stored in its native format, and the preprocessing can be done in the `tf.data` pipeline itself as the data is loaded!\n",
    "\n",
    "This is probably the most complex approach, but it can be useful if converting to `TFRecord` is difficult, such as when you don't want to save preprocessed images. It's especially useful when the dataset you want is already publicly available in cloud storage - this saves you having to create (and pay for!) your own cloud storage bucket.\n",
    "\n",
    "Many Hugging Face NLP models have complex tokenization schemes that are not yet supported as `tf.data` operations, and so this approach will not work for them. However, some (e.g. BERT) do have [fully TF compilable tokenization](https://huggingface.co/docs/transformers/model_doc/bert#transformers.TFBertTokenizer). This is often a great approach for image models, though!\n",
    "\n",
    "Let's see an example of this in action."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DD-7Nd4ZQ3RY"
   },
   "source": [
    "First, we initialize our TPU. Skip this block if you're running on CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GOqXA2PJQ3Ra"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "resolver = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
    "# On TPU VMs use this line instead:\n",
    "# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=\"local\")\n",
    "tf.config.experimental_connect_to_cluster(resolver)\n",
    "tf.tpu.experimental.initialize_tpu_system(resolver)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Plea6qCOQ-q3"
   },
   "source": [
    "Next, we create our strategy as we did in the first example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PggnR5T5RIp3"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "strategy = tf.distribute.TPUStrategy(resolver)\n",
    "# For testing without a TPU use this line instead:\n",
    "# strategy = tf.distribute.OneDeviceStrategy(\"/cpu:0\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bDM7iIc2Q7fM"
   },
   "source": [
    "Next, let's download an image dataset. We'll use Hugging Face datasets for this, but you can use any other source too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P6w5eK41uKu6"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "image_dataset = load_dataset(\"beans\", split=\"train\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vozfi3Xgzz8w"
   },
   "source": [
    "Now, let's get a list of the underlying image file paths and labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SMtZbeyGxECr"
   },
   "outputs": [],
   "source": [
    "filenames = image_dataset[\"image_file_path\"]\n",
    "labels = image_dataset[\"labels\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ordinarily at this point you would need to create your own bucket in Google Cloud Storage, upload the image files to there, and handle authenticating your Python code so it can access it. However, for the sake of this example we have uploaded the images to a public bucket for you, so you can get started quickly!\n",
    "\n",
    "We'll use this quick conversion below to turn the local filenames in the dataset into `gs://` paths in Google Cloud Storage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strip everything but the category directory and filenames\n",
    "base_filenames = ['/'.join(filename.split('/')[-2:]) for filename in filenames]\n",
    "# Prepend the google cloud base path to everything instead\n",
    "gs_paths = [\"gs://matt-tf-tpu-tutorial-datasets/beans/\"+filename for filename in base_filenames]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_dataset = tf.data.Dataset.from_tensor_slices(\n",
    "    {\"filename\": gs_paths, \"labels\": labels}\n",
    ")\n",
    "tf_dataset = tf_dataset.shuffle(len(tf_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dW5jDak11VBG"
   },
   "source": [
    "That was pretty painless, but now we come to the tricky bit. It's extremely important to preprocess data in the way that the model expects. Classes like `AutoTokenizer` and `AutoImageProcessor` are designed to easily load the exact configuration for any model, so that you're guaranteed that your preprocessing will be correct.\n",
    "\n",
    "However, this might seem like a problem when we need to do the preprocessing in `tf.data`! These classes contain framework-agnostic code which `tf.data` will usually not be able to compile into a pipeline. Don't panic, though - for image datasets we can simply get the normalization values from those classes, and then use them in our `tf.data` pipeline.\n",
    "\n",
    "Let's use [ViT](https://huggingface.co/google/vit-base-patch16-224) as our image model, and get the `mean` and `std` values used to normalize images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eod4Z6wl1Uew"
   },
   "outputs": [],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "\n",
    "image_model_checkpoint = \"google/vit-base-patch16-224\"\n",
    "\n",
    "processor = AutoImageProcessor.from_pretrained(image_model_checkpoint)\n",
    "image_size = (processor.size[\"height\"], processor.size[\"width\"])\n",
    "image_mean = processor.image_mean\n",
    "image_std = processor.image_std"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "T8qVXeYD4_NV"
   },
   "source": [
    "Now we can write a function to load and preprocess the images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pEeonD69zvpM"
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 8 * strategy.num_replicas_in_sync\n",
    "\n",
    "\n",
    "def decode_fn(sample):\n",
    "    image_data = tf.io.read_file(sample[\"filename\"])\n",
    "    image = tf.io.decode_jpeg(image_data, channels=3)\n",
    "    image = tf.image.resize(image, image_size)\n",
    "    array = tf.cast(image, tf.float32)\n",
    "    array /= 255.0\n",
    "    array = (array - image_mean) / image_std\n",
    "    array = tf.transpose(array, perm=[2, 0, 1])  # Swap to channels-first\n",
    "    return {\"pixel_values\": array, \"labels\": sample[\"labels\"]}\n",
    "\n",
    "tf_dataset = tf_dataset.map(decode_fn)\n",
    "tf_dataset = tf_dataset.batch(BATCH_SIZE, drop_remainder=True)\n",
    "print(tf_dataset.element_spec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bzJODzWU7rj5"
   },
   "source": [
    "Nice! Now we have a pipeline we can feed our model with. Let's try it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mlbRCqUJ06tm"
   },
   "outputs": [],
   "source": [
    "from transformers import TFAutoModelForImageClassification\n",
    "\n",
    "with strategy.scope():\n",
    "    model = TFAutoModelForImageClassification.from_pretrained(image_model_checkpoint)\n",
    "    model.compile(optimizer=\"adam\")\n",
    "\n",
    "model.fit(tf_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MSpj5NdPBWL5"
   },
   "source": [
    "In summary:\n",
    "\n",
    "**tf.data pipeline advantages:**\n",
    "- Very suitable for big data that is highly compressed in its native format (images, audio)\n",
    "- Very convenient if the raw data is already available in a public cloud bucket\n",
    "- Works on all TPU instances (if the data is stored in Google Cloud)\n",
    "\n",
    "**tf.data pipeline disadvantages:**\n",
    "- You'll need to write a full preprocessing pipeline\n",
    "- If preprocessing is complex, doing it on-the-fly can hurt throughput\n",
    "- If the data isn't already available in cloud storage you'll have to put it there\n",
    "- Less suitable for text data because writing a tokenization pipeline is hard, plus tokenized text is small and suitable for TFRecord"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L9IijL2AnCi3"
   },
   "source": [
    "### Stream from your dataset with `model.prepare_tf_dataset()`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DtJMrPVonXHd"
   },
   "source": [
    "If you've read any of our other [example notebooks](https://huggingface.co/docs/transformers/notebooks), you'll notice we often use the method `Dataset.to_tf_dataset()` or its higher-level wrapper `model.prepare_tf_dataset()` to convert Hugging Face Datasets to `tf.data.Dataset`. These methods can work for TPU, but with several caveats!\n",
    "\n",
    "The main thing to know is that these methods do not actually convert the entire Hugging Face `Dataset`. Instead, they create a `tf.data` pipeline that loads samples from the `Dataset`. This pipeline uses `tf.numpy_function` or `Dataset.from_generator()` to access the underlying `Dataset`, and as a result the whole pipeline cannot be compiled by TensorFlow. **Because of this, and because the pipeline streams from data on a local disc, these methods will not work on Colab TPU or TPU Nodes.**\n",
    "\n",
    "However, if you're running on a TPU VM and you can tolerate TensorFlow throwing some warnings, this method can work! Let's see it in action. By default, the code below will run on CPU so you can try it on Colab, but if you have a TPU VM feel free to try running it on TPU there."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-_uMLSSiHRU2"
   },
   "source": [
    "First, we initialize our TPU. Skip this block if you're running on CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Xvc_wpu0HiJJ"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "resolver = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
    "# On TPU VMs use this line instead:\n",
    "# resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=\"local\")\n",
    "tf.config.experimental_connect_to_cluster(resolver)\n",
    "tf.tpu.experimental.initialize_tpu_system(resolver)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "LlFyGsACHri-"
   },
   "source": [
    "Next, we load our strategy, dataset, tokenizer and model just like we did in the first example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OzkOZJZkH0ek"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from transformers import AutoTokenizer, TFAutoModelForSequenceClassification\n",
    "from datasets import load_dataset\n",
    "\n",
    "# By default, we run on CPU so you can try this code on Colab\n",
    "strategy = tf.distribute.OneDeviceStrategy(\"/cpu:0\")\n",
    "# When actually running on a TPU VM use this line instead:\n",
    "# strategy = tf.distribute.TPUStrategy(resolver)\n",
    "\n",
    "BATCH_SIZE = 8 * strategy.num_replicas_in_sync\n",
    "\n",
    "dataset = load_dataset(\"glue\", \"cola\", split=\"train\")\n",
    "\n",
    "model_checkpoint = \"distilbert-base-cased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "\n",
    "with strategy.scope():\n",
    "    model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint)\n",
    "    model.compile(optimizer=\"adam\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pl6J1EA5InGB"
   },
   "source": [
    "Next, we add the tokenizer output as columns in the dataset. Since the dataset is stored on disc, this means we can handle data much bigger than our available memory. Once that's done, we can use `prepare_tf_dataset` to stream data from the Hugging Face Dataset by wrapping it with a `tf.data` pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DeoeHixsnGip"
   },
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    return tokenizer(\n",
    "        examples[\"sentence\"], padding=\"max_length\", truncation=True, max_length=128\n",
    "    )\n",
    "\n",
    "\n",
    "# This will add the tokenizer output to the dataset as new columns\n",
    "dataset = dataset.map(tokenize_function)\n",
    "\n",
    "# prepare_tf_dataset() will choose columns that match the model's input names\n",
    "tf_dataset = model.prepare_tf_dataset(\n",
    "    dataset, batch_size=BATCH_SIZE, shuffle=True, tokenizer=tokenizer\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7T6EeVG2t1N9"
   },
   "source": [
    "And now you can fit this dataset just like before!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7AVwd5UpsYt-"
   },
   "outputs": [],
   "source": [
    "model.fit(tf_dataset)  # Note - will be very slow if you're on CPU"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-pAuOHsvt75z"
   },
   "source": [
    "In summary:\n",
    "\n",
    "**prepare_tf_dataset() advantages:**\n",
    "- Simple code\n",
    "- Same approach works on TPU and GPU\n",
    "- Dataset doesn't have to fit in memory\n",
    "- Can support variable rather than constant padding\n",
    "\n",
    "**prepare_tf_dataset() disadvantages:**\n",
    "- Only works on TPU VM, not on TPU Node/Colab\n",
    "- Data must be available as a Hugging Face Dataset\n",
    "- Data must fit on local storage\n",
    "- If you're using a big TPU pod slice, data loading may be a bottleneck\n",
    "- TensorFlow will yell at you a bit\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "rZSNCjbiTJ0y",
    "jVcUx-zUUKUc",
    "iW-kSIUxXpRe",
    "OWZoq4MRqcUn",
    "S-Km-TQjvy0v",
    "L9IijL2AnCi3",
    "xIhnkALTvDML",
    "J7tx2EsTtGZr"
   ],
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
